import json
import random

import numpy as np
import pandas as pd
import torch

from Asia_Modular_Training.asia_graph import set_asia_graph
from ModularUtils.ControllerConstants import generate_permutations, map_dictfill_to_discrete
from ModularUtils.ControllerModel import get_fake_distribution, get_generated_labels
from ModularUtils.DigitImageGeneration.mnist_image_generation import plot_trained_digits
from ModularUtils.Experiment_Class import Experiment
from ModularUtils.FrontBackDoorCalculation import estiamte_ate_frontdoor_direct, estiamte_ate_backdoor_direct
from ModularUtils.FunctionsConstant import asKey, getdoKey
from ModularUtils.FunctionsDistribution import get_joint_distributions_from_samples, calculate_TVD, calculate_KL
from ModularUtils.FunctionsTraining import save_results
from ModularUtils.Functions_Plot_Results import plot_saved_results


def _record_divergence(tvd_diff, kl_diff, query_str, fake_dist_dict, dataset_dist_dict):
    """Append the rounded TVD and KL between two distributions to the histories.

    Both metrics are always computed, but they are only stored when
    ``query_str`` is already a key of ``tvd_diff``; queries that are not being
    tracked are ignored.
    """
    obs_tvd = calculate_TVD(fake_dist_dict, dataset_dist_dict, doPrint=False)
    obs_kl = calculate_KL(fake_dist_dict, dataset_dist_dict, doPrint=False)

    if query_str in tvd_diff:
        # TODO(review): the original author marked these appends "fix it"
        # without saying what is wrong - confirm before relying on them.
        tvd_diff[query_str].append(round(obs_tvd, 4))
        kl_diff[query_str].append(round(obs_kl, 4))


def _evaluate_observational(Exp, label_generators, dataset_dict, key, compare_Var,
                            tvd_diff, kl_diff):
    """Compare the generated joint of ``compare_Var`` with the dataset samples for ``key``.

    Raises:
        KeyError: if ``dataset_dict`` holds no entry for ``asKey(key)``.
    """
    intv_key = asKey(key)
    if intv_key not in dataset_dict:
        # Without real samples there is no reference distribution to compare
        # against; fail with a clear message instead of an unrelated
        # AttributeError further down.
        raise KeyError(
            f"dataset_dict has no samples for intervention key {intv_key!r}; "
            f"cannot evaluate {compare_Var}"
        )

    obs_indices = [Exp.label_names.index(lb) for lb in compare_Var]
    # The samples are consumed as a host-side integer array, so they are
    # converted directly rather than being moved to the training device first.
    current_real_label = (
        dataset_dict[intv_key]["obs"][:, obs_indices]
        .type(torch.LongTensor)
        .view(-1, len(obs_indices))
        .detach()
        .cpu()
        .numpy()
        .astype(int)
    )

    query_str = getdoKey(compare_Var, dict(intv_key))
    fake_dist_dict = get_fake_distribution(Exp, label_generators, intv_key, compare_Var)
    dataset_dist_dict = get_joint_distributions_from_samples(
        Exp, compare_Var, current_real_label, "feature")

    _record_divergence(tvd_diff, kl_diff, query_str, fake_dist_dict, dataset_dist_dict)


def _evaluate_adjusted(Exp, label_generators, dataset_dict, key, compare_Var,
                       estimator, treatment, adjustment_set, tvd_diff, kl_diff):
    """Compare the generated distribution under ``key`` with an adjustment estimate.

    The reference is ``estimator(Exp, px, treatment, "dysp", adjustment_set)``
    evaluated on the observational samples of lung/either/dysp and indexed by
    the first value of the intervention ``key``.
    """
    # NOTE(review): the observational path passes asKey(key) to
    # get_fake_distribution while this path passes the raw key - confirm the
    # helper accepts both forms.
    fake_dist_dict = get_fake_distribution(Exp, label_generators, key, compare_Var)
    # Re-key each entry by the first element of its original key.
    fake_dist_dict = {list(outcome)[0]: prob for outcome, prob in fake_dist_dict.items()}

    variables = ["lung", "either", "dysp"]
    obs_indices = [Exp.label_names.index(lb) for lb in variables]
    current_real_label = dataset_dict[asKey({})]["obs"][:, obs_indices].detach().cpu().numpy()

    # Append the rows returned by generate_permutations for the three label
    # dimensions - presumably so that every value combination occurs at least
    # once in the frame handed to the estimator; verify against the estimator.
    all_comb = generate_permutations([Exp.label_dim[lb] for lb in variables])
    current_real_label = np.concatenate((current_real_label, all_comb), axis=0)

    px = pd.DataFrame(current_real_label, columns=variables)
    dataset_dist_dict = estimator(Exp, px, treatment, "dysp", adjustment_set)[
        list(key.values())[0]]

    query_str = getdoKey(compare_Var, dict(key))
    _record_divergence(tvd_diff, kl_diff, query_str, fake_dist_dict, dataset_dist_dict)


def asiaEvaluation(Exp, cur_hnodes, label_generators, dataset_dict, tvd_diff, kl_diff):
    """Evaluate the label generators on every query in ``Exp.interv_queries``.

    For each query and each of its interventions:

    * an empty intervention (``{}``) with a non-empty ``query["obs"]`` is scored
      against the joint distribution of the matching dataset samples;
    * ``'P(dysp|do(lung))'`` is scored against ``estiamte_ate_frontdoor_direct``
      (adjusting on ``either``) and ``"P(dysp|do(either))"`` against
      ``estiamte_ate_backdoor_direct`` (adjusting on ``lung``), both computed
      from the observational samples.

    The resulting TVD/KL values are appended to ``tvd_diff`` / ``kl_diff``
    (only for query strings that are already keys of ``tvd_diff``), the
    histories are passed to ``save_results``, and the last ten TVD values of
    every tracked query are printed together with ``Exp.SAVED_PATH``.

    Args:
        Exp: experiment object; this function reads ``interv_queries``,
            ``label_names``, ``label_dim``, ``SAVED_PATH``, ``G_avg_losses``
            and ``D_avg_losses`` from it.
        cur_hnodes: unused; kept so existing callers keep working.
        label_generators: mapping whose values expose ``eval()`` and
            ``train()``; they are put in eval mode for the duration of the
            evaluation and switched back to train mode afterwards, even when
            the evaluation raises.
        dataset_dict: mapping from ``asKey(intervention)`` to an entry whose
            ``"obs"`` item is a tensor indexed as ``[:, columns]``.
        tvd_diff: mapping from query string to a list of TVD values; mutated
            in place.
        kl_diff: mapping from query string to a list of KL values; mutated in
            place.

    Returns:
        The same ``(tvd_diff, kl_diff)`` objects that were passed in.

    Raises:
        KeyError: if the dataset samples needed for a query are missing from
            ``dataset_dict``.
    """
    for gen in label_generators:
        label_generators[gen].eval()

    try:
        with torch.no_grad():
            # Nothing is collected into these here; they are handed to
            # save_results empty.
            all_generated_labels = {}
            all_real_labels = {}

            for query in Exp.interv_queries:
                for key in query["intervs"]:
                    compare_Var = query["obs"]

                    if key == {}:
                        if len(compare_Var) == 0:
                            continue
                        _evaluate_observational(Exp, label_generators, dataset_dict, key,
                                                compare_Var, tvd_diff, kl_diff)

                    expr = query["expr"]
                    if expr == 'P(dysp|do(lung))':
                        estimator = estiamte_ate_frontdoor_direct
                        treatment, adjustment_set = 'lung', ['either']
                    elif expr == "P(dysp|do(either))":
                        estimator = estiamte_ate_backdoor_direct
                        treatment, adjustment_set = 'either', ['lung']
                    else:
                        continue

                    _evaluate_adjusted(Exp, label_generators, dataset_dict, key, compare_Var,
                                       estimator, treatment, adjustment_set, tvd_diff, kl_diff)

            save_results(Exp, Exp.SAVED_PATH, all_generated_labels, all_real_labels,
                         tvd_diff, kl_diff, Exp.G_avg_losses, Exp.D_avg_losses)
    finally:
        # Always hand the generators back in training mode, even if the
        # evaluation above failed part-way through.
        for gen in label_generators:
            label_generators[gen].train()

    # Print the most recent values of every tracked query. Each history is
    # sliced on its own, so an empty or shorter first history cannot change
    # how much of the others is shown.
    for dist in tvd_diff:
        print("###", dist, " loss%:", [round(val, 4) for val in tvd_diff[dist][-10:]])
    print(Exp.SAVED_PATH)

    return tvd_diff, kl_diff




# Build the "Exp1" experiment on the Asia graph with a single "feature"
# feature and one empty intervention. new_experiment=False presumably reopens
# an existing run instead of creating a fresh one - confirm against Experiment.
Exp = Experiment(
    "Exp1",
    set_asia_graph,
    new_experiment=False,
    features=["feature"],
    Data_intervs=[{}],
)

# Placeholder location of the saved run to plot: "<root>/<exp>".
root = "root_path"
exp = "exp_date"
bnc_exp = []

# Example of an explicit label list that could be passed as pre_labels:
#   ['$P(D,A)$', '$ncmP(D,A)$', 'rep$P(D,A)$',
#    'P(A|do(D=0))', 'ncmP(A|do(D=0))', 'repP(A|do(D=0))',
#    'P(A|do(D=1))', 'ncmP(A|do(D=1))', 'repP(A|do(D=1))']

last_exp = "/".join((root, exp))
benchmarks = []

# No benchmark runs are supplied, so only this run is passed to the plot
# (the original note here read "only whatifgan").
plot_saved_results(
    Exp,
    last_exp,
    epochs=1000,
    delta=10,
    pre_labels=None,
    benchmarks=benchmarks,
)